import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


class DatasetSplit(Dataset):
  """An abstract Dataset class wrapped around Pytorch Dataset class."""
  
  def __init__(self, dataset, idxs):
    self.dataset = dataset
    self.idxs = [int(i) for i in idxs]
  
  def __len__(self):
    return len(self.idxs)
  
  def __getitem__(self, item):
    image, label = self.dataset[self.idxs[item]]
    return torch.tensor(image), torch.tensor(label)


class LocalUpdate(object):
  def __init__(self, args, dataset, idxs, logger):
    self.args = args
    self.logger = logger
    self.trainloader, self.validloader, self.testloader = self.train_val_test(dataset, list(idxs))
    self.device = 'cuda' if args.gpu and torch.cuda.is_available() else 'cpu'
    # Default criterion set to NLL loss function
    self.criterion = nn.NLLLoss().to(self.device)
  
  def train_val_test(self, dataset, idxs):
    """
    Returns train, validation and test dataloaders for a given dataset and user indexes.
    """
    # split indexes for train, validation, and test (80, 10, 10)
    idxs_train = idxs[:int(0.8 * len(idxs))]
    idxs_val = idxs[int(0.8 * len(idxs)):int(0.9 * len(idxs))]
    idxs_test = idxs[int(0.9 * len(idxs)):]
    
    trainloader = DataLoader(DatasetSplit(dataset, idxs_train), batch_size=self.args.local_bs, shuffle=True)
    validloader = DataLoader(DatasetSplit(dataset, idxs_val), batch_size=int(len(idxs_val) / 10), shuffle=False)
    testloader = DataLoader(DatasetSplit(dataset, idxs_test), batch_size=int(len(idxs_test) / 10), shuffle=False)
    return trainloader, validloader, testloader
  
  def update_weights(self, model, global_round):
    # Set mode to train model
    model.train()
    epoch_loss = []
    
    # Set optimizer for the local updates
    if self.args.optimizer == 'sgd':
      optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr, momentum=0.5)
    elif self.args.optimizer == 'adam':
      optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr, weight_decay=1e-4)
    
    for local_epoch in range(self.args.local_ep):
      batch_loss = []
      for batch_idx, (images, labels) in enumerate(self.trainloader):
        images, labels = images.to(self.device), labels.to(self.device)
        
        model.zero_grad()
        log_probs = model(images)
        loss = self.criterion(log_probs, labels)
        loss.backward()
        optimizer.step()
        
        if self.args.verbose and (batch_idx % 10 == 0):
          print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)] \t Loss: {:.6f}'.format(
            global_round, local_epoch, batch_idx * len(images),
            len(self.trainloader.dataset), 100. * batch_idx / len(self.trainloader), loss.item()))
        self.logger.add_scalar('loss', loss.item())
        batch_loss.append(loss.item())
      epoch_loss.append(sum(batch_loss) / len(batch_loss))
    
    return model.state_dict(), sum(epoch_loss) / len(epoch_loss)
  
  def inference(self, model):
    """ Returns the inference accuracy and loss."""
    
    model.eval()
    loss, total, correct = 0.0, 0.0, 0.0
    
    for batch_idx, (images, labels) in enumerate(self.testloader):
      images, labels = images.to(self.device), labels.to(self.device)
      
      # Inference
      outputs = model(images)
      batch_loss = self.criterion(outputs, labels)
      loss += batch_loss.item()
      
      # Prediction
      _, pred_labels = torch.max(outputs, 1)
      pred_labels = pred_labels.view(-1)
      correct += torch.sum(torch.eq(pred_labels, labels)).item()
      total += len(labels)
    
    accuracy = correct / total
    return accuracy, loss


def test_inference(args, model, test_dataset):
  """ Returns the test accuracy and loss. """
  
  model.eval()
  loss, total, correct = 0.0, 0.0, 0.0
  
  device = 'cuda' if args.gpu and torch.cuda.is_available() else 'cpu'
  
  criterion = nn.NLLLoss().to(device)
  testloader = DataLoader(test_dataset, batch_size=128, shuffle=False)
  
  for batch_idx, (images, labels) in enumerate(testloader):
    images, labels = images.to(device), labels.to(device)
    
    # Inference
    outputs = model(images)
    batch_loss = criterion(outputs, labels)
    loss += batch_loss.item()
    
    # Prediction
    _, pred_labels = torch.max(outputs, 1)
    pred_labels = pred_labels.view(-1)
    correct += torch.sum(torch.eq(pred_labels, labels)).item()
    total += len(labels)
  
  accuracy = correct / total
  return accuracy, loss
